package server_test

import (
	"strings"
	"testing"
	"time"

	runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1"
	"github.com/rilldata/rill/runtime/pkg/activity"
	"github.com/rilldata/rill/runtime/pkg/ratelimit"
	"github.com/rilldata/rill/runtime/server"
	"github.com/rilldata/rill/runtime/testruntime"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// TestResolveTemplatedString_MetricsSQL checks how ResolveTemplatedString
// renders bodies that embed metrics_sql template calls.
//
// The fixture project defines two metrics views over inline SQL models:
//   - mv1 (model1): four rows; total_revenue sums to 700 and total_orders to
//     50 across all rows, and the US and UK rows alone sum to 300 revenue.
//   - mv2 (model2): three rows; total_sales sums to 900, of which the
//     Electronics row contributes 500.
//
// Each table entry sends one request and compares the response body either
// exactly or by substring, depending on the entry.
func TestResolveTemplatedString_MetricsSQL(t *testing.T) {
	files := map[string]string{
		"rill.yaml": "",
		"model1.sql": `
SELECT 'US' AS country, DATE '2024-01-01' AS order_date, 100 AS revenue, 5 AS orders
UNION ALL
SELECT 'UK' AS country, DATE '2024-01-15' AS order_date, 200 AS revenue, 10 AS orders
UNION ALL
SELECT 'CA' AS country, DATE '2024-02-01' AS order_date, 150 AS revenue, 15 AS orders
UNION ALL
SELECT 'CA' AS country, DATE '2024-02-15' AS order_date, 250 AS revenue, 20 AS orders
`,
		"mv1.yaml": `
type: metrics_view
version: 1
model: model1
timeseries: order_date
dimensions:
- column: country
- column: order_date
  name: order_date
measures:
- name: total_revenue
  expression: SUM(revenue)
- name: total_orders
  expression: SUM(orders)
`,
		"model2.sql": `
SELECT 'Electronics' AS category, 500 AS sales
UNION ALL
SELECT 'Clothing' AS category, 250 AS sales
UNION ALL
SELECT 'Food' AS category, 150 AS sales
`,
		"mv2.yaml": `
type: metrics_view
version: 1
model: model2
dimensions:
- column: category
measures:
- name: total_sales
  expression: SUM(sales)
`,
	}

	rt, instanceID := testruntime.NewInstanceWithOptions(t, testruntime.InstanceOptions{Files: files})
	// NOTE(review): 5, 0, 0 presumably means five resources with no reconcile
	// or parse errors — confirm against testruntime.RequireReconcileState.
	testruntime.RequireReconcileState(t, rt, instanceID, 5, 0, 0)

	srv, err := server.NewServer(t.Context(), &server.Options{}, rt, zap.NewNop(), ratelimit.NewNoop(), activity.NewNoopClient())
	require.NoError(t, err)

	// countryIn returns a freshly allocated `country IN (...)` filter
	// expression so that no two test cases share the same message.
	countryIn := func(countries ...string) *runtimev1.Expression {
		vals := make([]*structpb.Value, 0, len(countries))
		for _, name := range countries {
			vals = append(vals, structpb.NewStringValue(name))
		}
		return &runtimev1.Expression{
			Expression: &runtimev1.Expression_Cond{
				Cond: &runtimev1.Condition{
					Op: runtimev1.Operation_OPERATION_IN,
					Exprs: []*runtimev1.Expression{
						{Expression: &runtimev1.Expression_Ident{Ident: "country"}},
						{Expression: &runtimev1.Expression_Val{Val: structpb.NewListValue(&structpb.ListValue{Values: vals})}},
					},
				},
			},
		}
	}

	cases := []struct {
		// name is the subtest name.
		name string
		// body is the template sent in the request.
		body string
		// tokens is forwarded as UseFormatTokens.
		tokens bool
		// where is forwarded as AdditionalWhereByMetricsView.
		where map[string]*runtimev1.Expression
		// timeRange is forwarded as AdditionalTimeRange.
		timeRange *runtimev1.TimeRange
		// want holds the expected output: joined and compared for equality
		// when exact is set, otherwise each entry must be a substring.
		want  []string
		exact bool
	}{
		{
			name:   "WithFormatTokens",
			body:   `Total: {{ metrics_sql "SELECT total_revenue FROM mv1" }}`,
			tokens: true,
			want:   []string{`__RILL__FORMAT__("mv1", "total_revenue", 700)`},
		},
		{
			name: "MultipleQueries",
			body: `Revenue: {{ metrics_sql "SELECT total_revenue FROM mv1" }}
		Orders: {{ metrics_sql "SELECT total_orders FROM mv1" }}`,
			want: []string{"Revenue: 700", "Orders: 50"},
		},
		{
			name: "MultipleMetricsViewsWithDifferentFilters",
			body: `Revenue: {{ metrics_sql "SELECT total_revenue FROM mv1" }}, Sales: {{ metrics_sql "SELECT total_sales FROM mv2" }}`,
			where: map[string]*runtimev1.Expression{
				"mv1": countryIn("US", "UK"),
				"mv2": {
					Expression: &runtimev1.Expression_Cond{
						Cond: &runtimev1.Condition{
							Op: runtimev1.Operation_OPERATION_EQ,
							Exprs: []*runtimev1.Expression{
								{Expression: &runtimev1.Expression_Ident{Ident: "category"}},
								{Expression: &runtimev1.Expression_Val{Val: structpb.NewStringValue("Electronics")}},
							},
						},
					},
				},
			},
			want:  []string{"Revenue: 300, Sales: 500"},
			exact: true,
		},
		{
			// The expected 300 covers only the two January rows, so the row
			// dated exactly at End (2024-02-01) is expected to be excluded.
			name: "WithAdditionalTimeRange",
			body: `Revenue: {{ metrics_sql "SELECT total_revenue FROM mv1" }}`,
			timeRange: &runtimev1.TimeRange{
				Start: timestamppb.New(time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC)),
				End:   timestamppb.New(time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)),
			},
			want:  []string{"Revenue: 300"},
			exact: true,
		},
		{
			// The dates in the body are plain text; no time range is sent, so
			// the expected 300 comes from the country filter alone.
			name:  "WithRefs",
			body:  `Revenue in US and UK from 2024-01-01 to 2024-02-01: {{ metrics_sql "SELECT total_revenue FROM {{ ref \"mv1\" }}" }}`,
			where: map[string]*runtimev1.Expression{"mv1": countryIn("US", "UK")},
			want:  []string{"Revenue in US and UK from 2024-01-01 to 2024-02-01: 300"},
			exact: true,
		},
		{
			name:  "WithMultipleRefs",
			body:  `{{ metrics_sql "SELECT total_revenue FROM {{ ref \"mv1\" }}" }} and {{ metrics_sql "SELECT total_sales FROM {{ ref \"mv2\" }}" }}`,
			want:  []string{"700 and 900"},
			exact: true,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			req := &runtimev1.ResolveTemplatedStringRequest{
				InstanceId:                   instanceID,
				Body:                         c.body,
				UseFormatTokens:              c.tokens,
				AdditionalWhereByMetricsView: c.where,
				AdditionalTimeRange:          c.timeRange,
			}
			resp, err := srv.ResolveTemplatedString(testCtx(), req)
			require.NoError(t, err)

			// Substring mode: every expected fragment must appear in the body.
			if !c.exact {
				for _, fragment := range c.want {
					require.Contains(t, resp.Body, fragment)
				}
				return
			}

			// Exact mode: the fragments concatenated must equal the whole body.
			require.Equal(t, strings.Join(c.want, ""), resp.Body)
		})
	}
}
